{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train [U-Net](https://arxiv.org/abs/1505.04597)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The [`nn`](../nn) module defines everything related to neural networks. \n",
    "By default, these neural networks expect a [tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) as input. Such a reader can be created with the [`reader`](../reader) module."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training a network can be split into the following steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# this is for path management in jupyter notebook\n",
    "# not necessary if you're running in terminal or other IDEs\n",
    "import os\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lab/anaconda3/envs/tf-aml/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "# imports and parameter settings\n",
    "import tensorflow as tf\n",
    "from nn import unet\n",
    "class_num = 2                 # class number in ground truth\n",
    "patch_size = (572, 572)       # input patch size\n",
    "lr = 1e-4                     # start learning rate\n",
    "ds = 60                       # #epochs before lr decays\n",
    "dr = 0.1                      # lr will decay to lr*dr\n",
    "epochs = 1                    # #epochs to train\n",
    "bs = 5                        # batch size\n",
    "suffix = 'test'               # user customize name for the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define network\n",
    "tf.reset_default_graph()\n",
    "unet = unet.UNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,\n",
    "                 epochs=epochs, batch_size=bs)\n",
    "overlap = unet.get_overlap()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# imports and parameter settings\n",
    "import numpy as np\n",
    "from collection import collectionMaker, collectionEditor\n",
    "ds_name = 'Inria'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gt_d255 might already exist, skip replacement\n"
     ]
    }
   ],
   "source": [
    "cm = collectionMaker.read_collection(raw_data_path=r'/media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles',\n",
    "                                     field_name='austin,chicago,kitsap,tyrol-w,vienna', # use all cities\n",
    "                                     field_id=','.join(str(i) for i in range(37)), # use all tiles\n",
    "                                     rgb_ext='RGB',\n",
    "                                     gt_ext='GT',\n",
    "                                     file_ext='tif',\n",
    "                                     force_run=False,\n",
    "                                     clc_name=ds_name)\n",
    "gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1/255, ['GT', 'gt_d255']).\\\n",
    "    run(force_run=False, file_ext='png', d_type=np.uint8,)\n",
    "cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])\n",
    "chan_mean = cm.meta_data['chan_mean'][:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=========================================Inria==========================================\n",
      "raw_data_path: /media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles\n",
      "field_name: ['austin', 'chicago', 'kitsap', 'tyrol-w', 'vienna']\n",
      "field_id: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36']\n",
      "clc_name: Inria\n",
      "tile_dim: (5000, 5000)\n",
      "chan_mean: [103.2341876  108.95194825 100.14192002]\n",
      "Source file: *RGB*.tif\n",
      "GT file: *gt_d255*.png\n",
      "========================================================================================\n"
     ]
    }
   ],
   "source": [
    "cm.print_meta_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Training and Validation Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 6~36 for training\n",
    "file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(6, 37)), field_ext='RGB,gt_d255')\n",
    "# 0~5 for validation\n",
    "file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(6)), field_ext='RGB,gt_d255')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Patch Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# imports and parameter settings\n",
    "from preprocess import patchExtractor\n",
    "tile_size = (5000, 5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "patch_list_train = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_train', overlap, overlap//2).\\\n",
    "    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()\n",
    "patch_list_valid = patchExtractor.PatchExtractor(patch_size, tile_size, ds_name+'_valid', overlap, overlap//2).\\\n",
    "    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Data Reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# imports and parameter settings\n",
    "from reader import dataReaderSegmentation, reader_utils\n",
    "valid_mult = 5 # validation can have a larger batch size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_init_op, valid_init_op, reader_op = dataReaderSegmentation.DataReaderSegmentationTrainValid(\n",
    "    patch_size, patch_list_train, patch_list_valid, batch_size=bs, chan_mean=chan_mean,\n",
    "    aug_func=[reader_utils.image_flipping, reader_utils.image_rotating], # augmentation function for training data\n",
    "    random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=valid_mult).read_op()\n",
    "feature, label = reader_op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[`hook`](../nn/hook.py) is used here to monitor the training/validation process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# imports and parameter settings\n",
    "import ersaPath\n",
    "from nn import hook, nn_utils\n",
    "sfn = 32 # start filter number for the U-Net\n",
    "n_train = 1000 # #samples per epoch\n",
    "n_valid = 1000//bs//valid_mult # #steps to run in validation\n",
    "gpu = 1\n",
    "verb_step = 200 # print out verbose messages every 200 steps\n",
    "nn_utils.set_gpu(gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lab/anaconda3/envs/tf-aml/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "unet.create_graph(feature, sfn)\n",
    "unet.compile(feature, label, n_train, n_valid, patch_size, ersaPath.PATH['model'], par_dir='test', loss_type='xent')\n",
    "train_hook = hook.ValueSummaryHook(verb_step, [unet.loss, unet.lr_op], value_names=['train_loss', 'learning_rate'],\n",
    "                                   print_val=[0]) # print loss, write loss and lr every 200 step\n",
    "# print&write validation loss every epoch\n",
    "valid_loss_hook = hook.ValueSummaryHook(unet.get_epoch_step(), [unet.loss],\n",
    "                                        value_names=['valid_loss'], log_time=True, run_time=unet.n_valid)\n",
    "# print&write validation IoU every epoch\n",
    "valid_iou_hook = hook.IoUSummaryHook(unet.get_epoch_step(), unet.loss_iou, log_time=True, run_time=unet.n_valid,\n",
    "                                     cust_str='\\t')\n",
    "# write validation images every epoch\n",
    "image_hook = hook.ImageValidSummaryHook(unet.get_epoch_step(), unet.valid_images, feature, label, unet.pred,\n",
    "                                        nn_utils.image_summary, img_mean=chan_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 200\ttrain_loss 0.464\n",
      "Eval @ Epoch 0 Step 200\tvalid_loss 0.411, Duration: 107.648\n",
      "\tStep 200\tIoU 0.205, Duration: 138.107\n"
     ]
    }
   ],
   "source": [
    "unet.train(train_hooks=[train_hook], valid_hooks=[valid_loss_hook, valid_iou_hook, image_hook],\n",
    "           train_init=train_init_op, valid_init=valid_init_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hooks also write data into [tensorboard](https://www.tensorflow.org/guide/summaries_and_tensorboard)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
